# -*- coding:utf-8 -*- 
# SyscallDef module: Syscall structure and the function that builds syscall ROP chains
from enum import Enum
from ropgenerator.exploit.HighLevelUtils import popMultiple, build_call
from ropgenerator.IO import verbose, string_bold, string_ropg, string_payload, error, string_special
from ropgenerator.semantic.Engine import search, getBaseAssertion
from ropgenerator.Database import QueryType
from ropgenerator.Constraints import Constraint, Assertion
import ropgenerator.Architecture as Arch

class ArgType(Enum):
    """Kinds of values a syscall argument may accept (checked by verifyArgType)."""

    INT = "int"
    STRING = "str"
    INT_OR_STRING = "int or str"
    
def verifyArgType(arg, argType):
    """
    Tell whether `arg` is an acceptable value for an argument of kind `argType`.

    ArgType.INT accepts int instances, ArgType.STRING accepts str instances,
    and ArgType.INT_OR_STRING accepts either. Any other `argType` gives False.
    """
    # Checked in this order, with ==, exactly like an if/elif chain would.
    accepted_by_kind = (
        (ArgType.INT, (int,)),
        (ArgType.STRING, (str,)),
        (ArgType.INT_OR_STRING, (str, int)),
    )
    for kind, accepted_types in accepted_by_kind:
        if argType == kind:
            return isinstance(arg, accepted_types)
    return False

class Syscall:
    """
    Description of a syscall, as consumed by build_syscall_Linux().

    Attributes set by the constructor:
        ret              -- return type (concatenated as a string in __str__)
        def_function     -- name of the function to try calling first; equals
                            `name` when `function` is not given (or is falsy)
        def_name         -- name of the syscall
        def_args         -- sequence of (type, name) pairs; in this module it
                            is only used by __str__
        arg_regs         -- one register name per argument; its length is the
                            number of arguments (see nb_args)
        arg_types        -- expected ArgType of each argument
        syscall_arg_regs -- extra (register, value) pairs set before the
                            syscall (presumably the syscall number -- confirm)
    """

    def __init__(self, retType, name, args, arg_types, arg_regs, syscall_arg_regs, function=None):
        # Use the syscall name as the function to call when none is supplied.
        self.def_function = function if function else name
        self.def_name = name
        self.ret = retType
        self.def_args = args
        self.arg_types = arg_types
        self.arg_regs = arg_regs
        self.syscall_arg_regs = syscall_arg_regs

    def __str__(self):
        """Return a C-like prototype: 'ret name(type arg, type arg, ...)'."""
        header = self.ret + " " + string_bold(self.name())
        params = ', '.join([param[0] + " " + string_special(param[1]) for param in self.def_args])
        return header + "(" + params + ")"

    def name(self):
        """Return the syscall name."""
        return self.def_name

    def function(self):
        """Return the name of the function to try calling before a raw syscall."""
        return self.def_function

    def nb_args(self):
        """Return the number of arguments (one per argument register)."""
        return len(self.arg_regs)
    
# Default for the `clmax` parameter of build_syscall_Linux(): the chain length
# limit used when building a syscall chain (unit not visible here; presumably
# a count of chain entries/gadgets -- TODO confirm).
SYSCALL_LMAX = 500

def build_syscall_Linux(syscall, arg_list, arch_bits, constraint=None, assertion = None, clmax=SYSCALL_LMAX, optimizeLen=False):
    """
    Build a ROP chain that performs `syscall` with the arguments `arg_list`.

    It first tries to call syscall.function() directly through build_call().
    If that fails:
      - when constraint.chainable.ret is falsy, it falls back to setting the
        registers with popMultiple() and appending an 'int 0x80' (32 bits) or
        'syscall' (64 bits) gadget;
      - otherwise it gives up and returns None.

    Parameters:
        syscall     -- Syscall describing the call
        arg_list    -- argument values, one per entry of syscall.arg_regs; each
                       must pass verifyArgType() against syscall.arg_types
        arch_bits   -- 32 or 64; selects the 'int 0x80' or 'syscall' gadget
        constraint  -- Constraint on gadgets (a new Constraint() when None)
        assertion   -- Assertion (getBaseAssertion() when None)
        clmax       -- chain length limit forwarded to build_call() and,
                       minus one, to popMultiple()
        optimizeLen -- forwarded to build_call() and popMultiple()

    Returns:
        The chain on success, None on failure (reasons are reported through
        error() / verbose()).
    """
    # Check the number of arguments
    nb_expected = syscall.nb_args()
    if( nb_expected != len(arg_list) ):
        error("Error. Expected {} arguments, got {}".format(nb_expected, len(arg_list)))
        return None
    # Check the type of each argument
    for i, arg in enumerate(arg_list):
        expected_type = syscall.arg_types[i]
        if( not verifyArgType(arg, expected_type) ):
            error("Argument error for '{}': expected '{}', got '{}'".format(arg, expected_type, type(arg)))
            return None
    # Default constraint and assertion
    if( constraint is None ):
        constraint = Constraint()
    if( assertion is None ):
        assertion = getBaseAssertion()

    # First try to call the function directly
    verbose("Trying to call {}() function directly".format(syscall.def_name))
    func_call = build_call(syscall.function(), arg_list, constraint, assertion, clmax=clmax, optimizeLen=optimizeLen)
    # build_call() signals failure by returning a str
    if( not isinstance(func_call, str) ):
        verbose("Success")
        return func_call
    if( constraint.chainable.ret ):
        verbose("Couldn't call {}() and return to ROPChain".format(syscall.def_name))
        return None
    verbose("Coudn't call {}(), try direct syscall".format(syscall.def_name))

    # Otherwise do the syscall directly.
    # Pair each argument register with its value, then append the extra
    # (register, value) pairs from syscall.syscall_arg_regs. list(zip(...)) is
    # required on Python 3, where zip() returns an iterator that cannot be
    # concatenated to a list.
    reg_values = list(zip(syscall.arg_regs, arg_list)) + list(syscall.syscall_arg_regs)
    args = [(Arch.n2r(x[0]), x[1]) for x in reg_values]
    # clmax-1: presumably keeps room for the syscall gadget appended below
    chain = popMultiple(args, constraint, assertion, clmax-1, optimizeLen=optimizeLen)
    if( not chain ):
        verbose("Failed to set registers for the {} syscall".format(syscall.def_name))
        return None
    # Pick the gadget that triggers the syscall
    if( arch_bits == 32 ):
        # int 0x80
        syscall_gadgets = search(QueryType.INT80, None, None, constraint, assertion)
    elif( arch_bits == 64 ):
        # syscall
        syscall_gadgets = search(QueryType.SYSCALL, None, None, constraint, assertion)
    else:
        # Without this branch syscall_gadgets would be unbound below
        error("Unsupported arch_bits value: {} (expected 32 or 64)".format(arch_bits))
        return None
    if( not syscall_gadgets ):
        verbose("Failed to find an 'int 0x80' OR 'syscall' gadget")
        return None
    chain.addChain(syscall_gadgets[0])
    verbose("Success")
    return chain
    
